{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 10.12 机器翻译"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.2.0 cpu\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "import os\n",
    "import io\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import torchtext.vocab as Vocab\n",
    "import torch.utils.data as Data\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "# Special tokens used below for padding, start-of-sequence and end-of-sequence\n",
    "PAD, BOS, EOS = '<pad>', '<bos>', '<eos>'\n",
    "# Restrict CUDA to GPU 0, then select the GPU when CUDA is available, else the CPU\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Show the PyTorch version and the selected device\n",
    "print(torch.__version__, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.12.1 读取和预处理数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_one_seq(seq_tokens, all_tokens, all_seqs, max_seq_len):\n",
    "    \"\"\"Record one sequence's tokens, then terminate and pad it in place.\n",
    "\n",
    "    The raw tokens are added to `all_tokens` (later used to build the vocabulary).\n",
    "    `seq_tokens` itself is then extended with EOS followed by as many PAD tokens\n",
    "    as are needed to reach `max_seq_len`, and appended to `all_seqs`.\n",
    "    \"\"\"\n",
    "    all_tokens.extend(seq_tokens)\n",
    "    # Computed before EOS is appended; a non-positive count adds no padding\n",
    "    num_pad = max_seq_len - len(seq_tokens) - 1\n",
    "    seq_tokens.append(EOS)\n",
    "    seq_tokens.extend([PAD] * num_pad)\n",
    "    all_seqs.append(seq_tokens)\n",
    "\n",
    "def build_data(all_tokens, all_seqs):\n",
    "    \"\"\"Build a vocabulary from `all_tokens` and index every sequence in `all_seqs`.\n",
    "\n",
    "    Returns the vocabulary together with a tensor holding one row of token\n",
    "    indices per sequence.\n",
    "    \"\"\"\n",
    "    token_counts = collections.Counter(all_tokens)\n",
    "    vocab = Vocab.Vocab(token_counts, specials=[PAD, BOS, EOS])\n",
    "    token_to_index = vocab.stoi\n",
    "    indices = []\n",
    "    for seq in all_seqs:\n",
    "        indices.append([token_to_index[w] for w in seq])\n",
    "    return vocab, torch.tensor(indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data(max_seq_len):\n",
    "    \"\"\"Read the parallel corpus and build the vocabularies and the dataset.\n",
    "\n",
    "    Each non-empty line of '../../data/fr-en-small.txt' holds an input and an\n",
    "    output sentence separated by a tab, with tokens separated by single spaces.\n",
    "    Pairs in which either sentence has more than `max_seq_len - 1` tokens are\n",
    "    dropped, because they would exceed `max_seq_len` once EOS is appended.\n",
    "\n",
    "    Returns:\n",
    "        (in_vocab, out_vocab, dataset), where dataset is a TensorDataset of the\n",
    "        indexed input and output sequences.\n",
    "    \"\"\"\n",
    "    # 'in' and 'out' are short for input and output\n",
    "    in_tokens, out_tokens, in_seqs, out_seqs = [], [], [], []\n",
    "    # Decode explicitly as UTF-8 so reading does not depend on the platform default\n",
    "    with io.open('../../data/fr-en-small.txt', encoding='utf-8') as f:\n",
    "        lines = f.readlines()\n",
    "    for line in lines:\n",
    "        line = line.rstrip()\n",
    "        if not line:\n",
    "            continue  # skip blank lines instead of failing on the tuple unpack\n",
    "        in_seq, out_seq = line.split('\\t')\n",
    "        in_seq_tokens, out_seq_tokens = in_seq.split(' '), out_seq.split(' ')\n",
    "        if max(len(in_seq_tokens), len(out_seq_tokens)) > max_seq_len - 1:\n",
    "            continue  # too long once EOS is appended, so ignore this sample\n",
    "        process_one_seq(in_seq_tokens, in_tokens, in_seqs, max_seq_len)\n",
    "        process_one_seq(out_seq_tokens, out_tokens, out_seqs, max_seq_len)\n",
    "    in_vocab, in_data = build_data(in_tokens, in_seqs)\n",
    "    out_vocab, out_data = build_data(out_tokens, out_seqs)\n",
    "    return in_vocab, out_vocab, Data.TensorDataset(in_data, out_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 5,  4, 45,  3,  2,  0,  0]), tensor([ 8,  4, 27,  3,  2,  0,  0]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Every sequence is terminated with EOS and padded to this length\n",
    "max_seq_len = 7\n",
    "in_vocab, out_vocab, dataset = read_data(max_seq_len)\n",
    "# Inspect the first (input, output) pair of index tensors\n",
    "dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.12.2 含注意力机制的编码器—解码器\n",
    "### 10.12.2.1 编码器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
    "                 drop_prob=0, **kwargs):\n",
    "        super(Encoder, self).__init__(**kwargs)\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=drop_prob)\n",
    "\n",
    "    def forward(self, inputs, state):\n",
    "        # 输入形状是(批量大小, 时间步数)。将输出互换样本维和时间步维\n",
    "        embedding = self.embedding(inputs.long()).permute(1, 0, 2) # (seq_len, batch, input_size)\n",
    "        return self.rnn(embedding, state)\n",
    "\n",
    "    def begin_state(self):\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([7, 4, 16]), torch.Size([2, 4, 16]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape check on a dummy batch of 4 sequences with 7 time steps each\n",
    "encoder = Encoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)\n",
    "output, state = encoder(torch.zeros((4, 7)), encoder.begin_state())\n",
    "output.shape, state.shape # a GRU state is just h, whereas an LSTM state is a tuple (h, c)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.12.2.2 注意力机制"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def attention_model(input_size, attention_size):\n",
    "    model = nn.Sequential(nn.Linear(input_size, attention_size, bias=False),\n",
    "                          nn.Tanh(),\n",
    "                          nn.Linear(attention_size, 1, bias=False))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def attention_forward(model, enc_states, dec_state):\n",
    "    \"\"\"\n",
    "    enc_states: (时间步数, 批量大小, 隐藏单元个数)\n",
    "    dec_state: (批量大小, 隐藏单元个数)\n",
    "    \"\"\"\n",
    "    # 将解码器隐藏状态广播到和编码器隐藏状态形状相同后进行连结\n",
    "    dec_states = dec_state.unsqueeze(dim=0).expand_as(enc_states)\n",
    "    enc_and_dec_states = torch.cat((enc_states, dec_states), dim=2)\n",
    "    e = model(enc_and_dec_states)  # 形状为(时间步数, 批量大小, 1)\n",
    "    alpha = F.softmax(e, dim=0)  # 在时间步维度做softmax运算\n",
    "    return (alpha * enc_states).sum(dim=0)  # 返回背景变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 8])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape check: 10 time steps, batch of 4, 8 hidden units\n",
    "seq_len, batch_size, num_hiddens = 10, 4, 8\n",
    "# The scorer input is an encoder state concatenated with the decoder state\n",
    "model = attention_model(2*num_hiddens, 10) \n",
    "enc_states = torch.zeros((seq_len, batch_size, num_hiddens))\n",
    "dec_state = torch.zeros((batch_size, num_hiddens))\n",
    "attention_forward(model, enc_states, dec_state).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.12.2.3 含注意力机制的解码器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    \"\"\"GRU decoder that attends over the encoder outputs at every step.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
    "                 attention_size, drop_prob=0):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        # The scorer sees an encoder state concatenated with the decoder state\n",
    "        self.attention = attention_model(2 * num_hiddens, attention_size)\n",
    "        # The GRU input is the embedded token joined with the context vector,\n",
    "        # hence num_hiddens + embed_size features\n",
    "        rnn_input_size = num_hiddens + embed_size\n",
    "        self.rnn = nn.GRU(rnn_input_size, num_hiddens, num_layers,\n",
    "                          dropout=drop_prob)\n",
    "        self.out = nn.Linear(num_hiddens, vocab_size)\n",
    "\n",
    "    def forward(self, cur_input, state, enc_states):\n",
    "        \"\"\"Decode a single time step.\n",
    "\n",
    "        cur_input shape: (batch, )\n",
    "        state shape: (num_layers, batch, num_hiddens)\n",
    "\n",
    "        Returns the output scores of shape (batch, vocab_size) and the new state.\n",
    "        \"\"\"\n",
    "        # Attend using the hidden state of the last GRU layer\n",
    "        last_layer_state = state[-1]\n",
    "        context = attention_forward(self.attention, enc_states, last_layer_state)\n",
    "        # Join embedding and context on the feature axis:\n",
    "        # (batch, num_hiddens + embed_size)\n",
    "        embedded = self.embedding(cur_input)\n",
    "        step_input = torch.cat((embedded, context), dim=1)\n",
    "        # Add a time axis of length 1 for the GRU\n",
    "        output, state = self.rnn(step_input.unsqueeze(0), state)\n",
    "        # Project to the vocabulary and drop the time axis again\n",
    "        scores = self.out(output).squeeze(dim=0)\n",
    "        return scores, state\n",
    "\n",
    "    def begin_state(self, enc_state):\n",
    "        \"\"\"Use the encoder's final hidden state as the initial decoder state.\"\"\"\n",
    "        return enc_state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.12.3 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_loss(encoder, decoder, X, Y, loss):\n",
    "    \"\"\"Compute the average per-token loss of one mini-batch with teacher forcing.\n",
    "\n",
    "    Args:\n",
    "        encoder, decoder: the two halves of the model.\n",
    "        X: input index tensor of shape (batch, seq_len).\n",
    "        Y: label index tensor of shape (batch, seq_len), expected on X's device.\n",
    "        loss: element-wise loss, called as loss(dec_output, y); it must return\n",
    "            one value per sample (no reduction).\n",
    "\n",
    "    Returns:\n",
    "        A one-element tensor: the summed loss over non-padding label tokens\n",
    "        divided by their count.\n",
    "\n",
    "    Reads the global `out_vocab` to look up the BOS and EOS indices.\n",
    "    \"\"\"\n",
    "    batch_size = X.shape[0]\n",
    "    # Create every helper tensor on the inputs' device so that the function\n",
    "    # also works when the models and data live on a GPU\n",
    "    dev = X.device\n",
    "    enc_state = encoder.begin_state()\n",
    "    enc_outputs, enc_state = encoder(X, enc_state)\n",
    "    # Initialize the decoder's hidden state\n",
    "    dec_state = decoder.begin_state(enc_state)\n",
    "    # The decoder's input at the first time step is BOS\n",
    "    dec_input = torch.tensor([out_vocab.stoi[BOS]] * batch_size, device=dev)\n",
    "    # The mask zeroes the loss of PAD labels; it starts as all ones\n",
    "    mask, num_not_pad_tokens = torch.ones(batch_size, device=dev), 0\n",
    "    l = torch.tensor([0.0], device=dev)\n",
    "    for y in Y.permute(1, 0):  # Y shape: (batch, seq_len)\n",
    "        dec_output, dec_state = decoder(dec_input, dec_state, enc_outputs)\n",
    "        l = l + (mask * loss(dec_output, y)).sum()\n",
    "        dec_input = y  # teacher forcing: feed the true label as the next input\n",
    "        num_not_pad_tokens += mask.sum().item()\n",
    "        # Only PAD follows EOS, so once a sample has produced EOS its mask\n",
    "        # entry stays 0 for all remaining steps\n",
    "        mask = mask * (y != out_vocab.stoi[EOS]).float()\n",
    "    return l / num_not_pad_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(encoder, decoder, dataset, lr, batch_size, num_epochs):\n",
    "    \"\"\"Train the encoder and the decoder jointly.\n",
    "\n",
    "    Each module gets its own Adam optimizer with learning rate `lr`. The data\n",
    "    is reshuffled every epoch, and the mean batch loss is printed every 10\n",
    "    epochs.\n",
    "    \"\"\"\n",
    "    # Make sure dropout is active even if the modules were left in eval mode\n",
    "    encoder.train()\n",
    "    decoder.train()\n",
    "    enc_optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)\n",
    "    dec_optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)\n",
    "\n",
    "    # Keep per-sample losses so that batch_loss can mask out padding\n",
    "    loss = nn.CrossEntropyLoss(reduction='none')\n",
    "    data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)\n",
    "    for epoch in range(num_epochs):\n",
    "        l_sum = 0.0\n",
    "        for X, Y in data_iter:\n",
    "            enc_optimizer.zero_grad()\n",
    "            dec_optimizer.zero_grad()\n",
    "            l = batch_loss(encoder, decoder, X, Y, loss)\n",
    "            l.backward()\n",
    "            enc_optimizer.step()\n",
    "            dec_optimizer.step()\n",
    "            l_sum += l.item()\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            print(\"epoch %d, loss %.3f\" % (epoch + 1, l_sum / len(data_iter)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, loss 0.475\n",
      "epoch 20, loss 0.245\n",
      "epoch 30, loss 0.157\n",
      "epoch 40, loss 0.052\n",
      "epoch 50, loss 0.039\n"
     ]
    }
   ],
   "source": [
    "# Model sizes and training hyperparameters\n",
    "embed_size, num_hiddens, num_layers = 64, 64, 2\n",
    "attention_size, drop_prob, lr, batch_size, num_epochs = 10, 0.5, 0.01, 2, 50\n",
    "# Build the encoder/decoder over the two vocabularies, then train them\n",
    "encoder = Encoder(len(in_vocab), embed_size, num_hiddens, num_layers,\n",
    "                  drop_prob)\n",
    "decoder = Decoder(len(out_vocab), embed_size, num_hiddens, num_layers,\n",
    "                  attention_size, drop_prob)\n",
    "train(encoder, decoder, dataset, lr, batch_size, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.12.4 预测不定长的序列"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate(encoder, decoder, input_seq, max_seq_len):\n",
    "    \"\"\"Greedily translate one space-separated input sentence.\n",
    "\n",
    "    The sentence is terminated with EOS and padded to `max_seq_len`, encoded,\n",
    "    and then decoded one token at a time by taking the arg-max prediction.\n",
    "    Decoding stops at EOS or after `max_seq_len` steps.\n",
    "\n",
    "    Inference runs in eval mode (dropout disabled) and without gradient\n",
    "    tracking; each module's previous train/eval mode is restored afterwards.\n",
    "\n",
    "    Reads the globals `in_vocab` and `out_vocab`. Returns the list of predicted\n",
    "    tokens, without EOS.\n",
    "    \"\"\"\n",
    "    in_tokens = input_seq.split(' ')\n",
    "    in_tokens += [EOS] + [PAD] * (max_seq_len - len(in_tokens) - 1)\n",
    "    # Put the inputs on the same device as the encoder's parameters\n",
    "    dev = next(encoder.parameters()).device\n",
    "    enc_input = torch.tensor([[in_vocab.stoi[tk] for tk in in_tokens]],\n",
    "                             device=dev)  # batch=1\n",
    "    # Remember the current modes so that they can be restored afterwards\n",
    "    enc_was_training, dec_was_training = encoder.training, decoder.training\n",
    "    encoder.eval()\n",
    "    decoder.eval()\n",
    "    output_tokens = []\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            enc_state = encoder.begin_state()\n",
    "            enc_output, enc_state = encoder(enc_input, enc_state)\n",
    "            dec_input = torch.tensor([out_vocab.stoi[BOS]], device=dev)\n",
    "            dec_state = decoder.begin_state(enc_state)\n",
    "            for _ in range(max_seq_len):\n",
    "                dec_output, dec_state = decoder(dec_input, dec_state, enc_output)\n",
    "                pred = dec_output.argmax(dim=1)\n",
    "                pred_token = out_vocab.itos[int(pred.item())]\n",
    "                if pred_token == EOS:  # the output is complete once EOS is predicted\n",
    "                    break\n",
    "                output_tokens.append(pred_token)\n",
    "                dec_input = pred\n",
    "    finally:\n",
    "        encoder.train(enc_was_training)\n",
    "        decoder.train(dec_was_training)\n",
    "    return output_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['they', 'are', 'watching', '.']"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Translate one sample sentence with the trained model\n",
    "input_seq = 'ils regardent .'\n",
    "translate(encoder, decoder, input_seq, max_seq_len)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.12.5 评价翻译结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bleu(pred_tokens, label_tokens, k):\n",
    "    len_pred, len_label = len(pred_tokens), len(label_tokens)\n",
    "    score = math.exp(min(0, 1 - len_label / len_pred))\n",
    "    for n in range(1, k + 1):\n",
    "        num_matches, label_subs = 0, collections.defaultdict(int)\n",
    "        for i in range(len_label - n + 1):\n",
    "            label_subs[''.join(label_tokens[i: i + n])] += 1\n",
    "        for i in range(len_pred - n + 1):\n",
    "            if label_subs[''.join(pred_tokens[i: i + n])] > 0:\n",
    "                num_matches += 1\n",
    "                label_subs[''.join(pred_tokens[i: i + n])] -= 1\n",
    "        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))\n",
    "    return score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score(input_seq, label_seq, k):\n",
    "    \"\"\"Translate `input_seq` and print its BLEU score against `label_seq`.\n",
    "\n",
    "    Uses the global `encoder`, `decoder` and `max_seq_len`; `k` is the largest\n",
    "    n-gram order passed on to `bleu`.\n",
    "    \"\"\"\n",
    "    pred_tokens = translate(encoder, decoder, input_seq, max_seq_len)\n",
    "    label_tokens = label_seq.split(' ')\n",
    "    bleu_value = bleu(pred_tokens, label_tokens, k)\n",
    "    prediction = ' '.join(pred_tokens)\n",
    "    print('bleu %.3f, predict: %s' % (bleu_value, prediction))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bleu 1.000, predict: they are watching .\n"
     ]
    }
   ],
   "source": [
    "# Score the model's translation against the reference using 1- and 2-grams\n",
    "score('ils regardent .', 'they are watching .', k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bleu 0.658, predict: they are exhausted .\n"
     ]
    }
   ],
   "source": [
    "# A second example; in the recorded run the prediction differs from the reference by one token\n",
    "score('ils sont canadienne .', 'they are canadian .', k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36]",
   "language": "python",
   "name": "conda-env-py36-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
